# src/utils/analysis.py
import numpy as np

def _fit_line(x, y):
    """Least-squares fit of ``y = b*x + a``; return ``(b, a, rss)``."""
    A = np.vstack([x, np.ones_like(x)]).T
    b, a = np.linalg.lstsq(A, y, rcond=None)[0]
    rss = float(np.sum((y - (b * x + a)) ** 2))
    return b, a, rss


def two_segment_bic(log_s, y, min_seg: int = 2, eps: float = 1e-9):
    """
    Piecewise-linear 2-segment change-point search scored by BIC.

    For every admissible split index ``k`` the data is divided into
    ``x[:k]`` and ``x[k:]``. An independent least-squares line is fitted
    to each part, and the split is scored with

        BIC = n * log(RSS / n) + 4 * log(n)

    Here RSS is the combined residual sum of squares of both lines. The 4
    counts two slopes plus two intercepts. The change-point location
    itself is not counted as a parameter. The penalty is identical for
    every candidate, so this only shifts the reported BIC value, not which
    split wins.

    Piecewise-constant steps (and very short segments) can be fitted
    exactly, which would make the combined RSS zero and the BIC -inf. The
    combined RSS is therefore floored at ``eps`` so the BIC stays finite.
    This function only picks the best split. Whether it beats a
    single-line model is left to the caller.

    Args:
        log_s: 1D array-like of log(separation) values. Segments are formed
            by index order, so pass the data already sorted by s.
        y:     1D array-like of PSI (or any monotone-ish response), same
            length as ``log_s``.
        min_seg: minimum points per segment, >= 1 (default 2 for 5-point
            sweeps).
        eps:   floor applied to the combined RSS so BIC stays finite
            (default 1e-9); expected to be positive.

    Returns:
        dict with keys {"k","bic","b1","a1","b2","a2"} for the lowest-BIC
        split. ``k`` is the index of the first point of the second
        segment. (b1, a1) and (b2, a2) are the slope/intercept of each
        segment's fit. Ties keep the earliest split.
        Returns None if there are fewer than ``2 * min_seg`` points, or if
        ``y`` is flat (``np.allclose`` to its first element).

    Raises:
        ValueError: if the inputs are not 1-D, their lengths differ, or
            ``min_seg < 1``.
    """
    x = np.asarray(log_s, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or y.ndim != 1:
        raise ValueError("log_s and y must be 1-D")
    # Without this check a length mismatch could silently return None
    # (via the flat check) or fail deep inside lstsq.
    if x.size != y.size:
        raise ValueError(
            f"log_s and y must have the same length, got {x.size} and {y.size}"
        )
    if min_seg < 1:
        raise ValueError(f"min_seg must be >= 1, got {min_seg}")

    n = x.size
    if n < 2 * min_seg:
        return None
    # If fully flat, nothing to split
    if np.allclose(y, y[0]):
        return None

    k_params = 4  # two slopes + two intercepts
    penalty = k_params * np.log(n)  # identical for every candidate split

    best = None
    # k ranges so both segments hold at least min_seg points.
    for k in range(min_seg, n - min_seg + 1):
        b1, a1, rss1 = _fit_line(x[:k], y[:k])
        b2, a2, rss2 = _fit_line(x[k:], y[k:])

        # Clamp combined RSS so BIC is finite even for perfect step fits
        rss = max(rss1 + rss2, eps)
        bic = n * np.log(rss / n) + penalty

        if best is None or bic < best["bic"]:
            best = {"k": k, "bic": float(bic),
                    "b1": float(b1), "a1": float(a1),
                    "b2": float(b2), "a2": float(a2)}
    return best


def simple_regression(x, y):
    """
    Fit an ordinary least-squares line ``y = slope * x + intercept``.

    Returns:
        dict with "slope", "intercept" and "R2" (coefficient of
        determination, 1 - SS_res / SS_tot). When y has (near-)zero
        variance, SS_tot is floored at 1e-12 instead of dividing by zero.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)

    design = np.vstack((xs, np.ones_like(xs))).T
    coeffs, *_ = np.linalg.lstsq(design, ys, rcond=None)
    slope, intercept = coeffs

    residuals = ys - (slope * xs + intercept)
    resid_ss = float(np.sum(residuals ** 2))
    total_ss = float(np.sum((ys - ys.mean()) ** 2))

    r_squared = 1.0 - resid_ss / max(total_ss, 1e-12)
    return {
        "slope": float(slope),
        "intercept": float(intercept),
        "R2": float(r_squared),
    }
